{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "1"
   },
   "source": [
    "# A Masked Language Model for Project CodeNet\n",
    "\n",
    "> Copyright (c) 2021 International Business Machines Corporation  \n",
    "Prepared by [Geert Janssen](geert@us.ibm.com>)\n",
    "\n",
    "## Introduction\n",
    "\n",
    "This experiment investigates whether a popular attention model to\n",
    "construct a masked language model (MLM) can be used for source code\n",
    "instead of natural language sentences. We here closely follow the\n",
    "approach by Ankur Singh documented in his\n",
    "[blog](https://keras.io/examples/nlp/masked_language_modeling).\n",
    "\n",
    "The goal of the model is to be able to infer the correct token for a\n",
    "masked-out token at an arbitrary position in the source text.\n",
    "We will use the special token literal `[mask]` to represent the masked\n",
    "out token. We assume that in the training and test sets precisely one token\n",
    "is randomly masked per sample. The original token at that position is\n",
    "then the golden label, or ground truth."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "cellView": "form",
    "id": "cjzhCC12YNfP"
   },
   "outputs": [],
   "source": [
    "#@title Imports\n",
    "import tensorflow as tf\n",
    "from tensorflow.keras.layers.experimental.preprocessing import TextVectorization\n",
    "from tensorflow.keras import Sequential, Model, losses, metrics, optimizers\n",
    "from tensorflow.keras import callbacks\n",
    "from tensorflow.keras.layers import MultiHeadAttention, LayerNormalization\n",
    "from tensorflow.keras.layers import Dense, Dropout, Input, Embedding\n",
    "from dataclasses import dataclass\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import glob\n",
    "import random\n",
    "import tarfile\n",
    "import requests\n",
    "import os\n",
    "import shutil"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "2"
   },
   "source": [
    "## Dataset\n",
    "\n",
    "The Project CodeNet dataset consist of a large collection (close to 14 million) of\n",
    "submissions in various programming languages to problems posed on\n",
    "online judging sites. The submissions are typically small, complete\n",
    "programs in a single source file. There are 1000s of problems in this dataset.\n",
    "\n",
    "We extract a selection of C programming language files from the\n",
    "Project CodeNet dataset for training and evaluation:\n",
    "\n",
    "| Aspect              | Value |\n",
    "| ------------------- | ----- |\n",
    "| purpose             | training |\n",
    "| submission status   | Accepted |\n",
    "| smallest size       | 200 bytes |\n",
    "| largest size        | 500 bytes |\n",
    "| samples per problem | at most 100 |\n",
    "| problems            | 0-3417 |\n",
    "| total samples       | 50,000 |\n",
    "\n",
    "Notice that the training and evaluation are derived from\n",
    "non-overlapping sets of problems. It might also be interesting to see\n",
    "what happens when a different split is made, e.g. by selecting 110\n",
    "submissions from each problem and using 100 for training and the rest for evaluation.\n",
    "\n",
    "| Aspect              | Value |\n",
    "| ------------------- |------ |\n",
    "| purpose             | evaluation |\n",
    "| submission status   | Accepted |\n",
    "| smallest size       | 200 bytes |\n",
    "| largest size        | 500 bytes |\n",
    "| samples per problem | at most 100 |\n",
    "| problems            | 3418-3636 |\n",
    "| total samples       | 5,000 |"
   ]
  },
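  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The alternative per-problem split mentioned above could be sketched roughly as follows.\n",
    "This is only an illustration and not part of the original experiment: the helper\n",
    "`split_per_problem`, the mapping `submissions_by_problem` (problem id to a list of\n",
    "submission names) and the toy data are all hypothetical."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "\n",
    "# Hypothetical helper: per problem, draw up to n_train + n_eval submissions,\n",
    "# then use the first n_train for training and the remainder for evaluation.\n",
    "def split_per_problem(submissions_by_problem, n_train=100, n_eval=10, seed=0):\n",
    "    rng = random.Random(seed)  # fixed seed, chosen only for repeatability\n",
    "    train, evaluation = [], []\n",
    "    for problem_id in sorted(submissions_by_problem):\n",
    "        picked = list(submissions_by_problem[problem_id])\n",
    "        rng.shuffle(picked)\n",
    "        picked = picked[:n_train + n_eval]\n",
    "        train.extend(picked[:n_train])\n",
    "        evaluation.extend(picked[n_train:])\n",
    "    return train, evaluation\n",
    "\n",
    "# Made-up toy data: two problems with 120 and 110 submissions.\n",
    "toy = {\n",
    "    'p0': [f'p0_s{i}' for i in range(120)],\n",
    "    'p1': [f'p1_s{i}' for i in range(110)],\n",
    "}\n",
    "train, evaluation = split_per_problem(toy)\n",
    "print(len(train), len(evaluation))"
   ]
  },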
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download the subset of the Project CodeNet data described above.\n",
    "file_name = \"Project_CodeNet_MLM.tar.gz\"\n",
    "data_url = f\"https://dax-cdn.cdn.appdomain.cloud/dax-project-codenet/1.0.0/{file_name}\"\n",
    "\n",
    "# Download tar archive to local disk\n",
    "if os.path.exists(file_name):\n",
    "    os.remove(file_name) \n",
    "with open(file_name, \"wb\") as f:\n",
    "    f.write(requests.get(data_url).content)\n",
    "    \n",
    "# Extract contents of archive to local disk\n",
    "if os.path.exists(\"tokens\"):\n",
    "    shutil.rmtree(\"tokens\")    \n",
    "with tarfile.open(file_name) as f:\n",
    "    f.extractall()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "3"
   },
   "source": [
    "## Data preparation\n",
    "\n",
    "Each C file is tokenized into a vocabulary of 414 distinct tokens:\n",
    "\n",
    "| Type           | Count | Description |\n",
    "| -------------- | ----: | -- |\n",
    "|the keyword     |    95 | all C++20 reserved words |\n",
    "|the function    |   279 | function names in common header files |\n",
    "|the identifier  |    18 | standard identifiers, like stderr, etc. |\n",
    "|the punctuator  |    15 | small set of punctuation symbols |\n",
    "|# or ##         |     2 | the 2 C preprocessor symbols |\n",
    "|the token class |     5 | one of: id, number, operator, character, string |\n",
    "\n",
    "By _the keyword_, _the function_ and so on, we mean the actual keyword\n",
    "or function literal, like `while` for a keyword and `strlen` for a function.\n",
    "The tokens are output on a single line separated by spaces.\n",
    "It turns out that our training set overall uses some 200 out of the 414\n",
    "possible tokens; not all keywords and standard functions are used presumably.\n",
    "\n",
    "This code snippet:\n",
    "```C\n",
    "for (i = 0; i < strlen(s); i++) {}\n",
    "```\n",
    "\n",
    "will be converted to:\n",
    "```C\n",
    "for ( id = number ; id < strlen ( id ) ; id operator ) { }\n",
    "```\n",
    "\n",
    "The tokenized source files are read into a pandas dataframe and\n",
    "processed by the Keras `TextVectorization` layer to extract a vocabulary\n",
    "and encode all token lines into vocabulary indices. Index 0 is\n",
    "reserved for padding; index 1 is the `<UNK>` value for Out-Of-Vocabulary\n",
    "tokens (not used in our case since the input vocabulary is of fixed\n",
    "size); the last index (least frequent position) is dedicated to encode\n",
    "the special `[mask]` token. Each sample will have a fixed token length\n",
    "of 256. The average number of tokens per sample across the training\n",
    "set is 131. Short samples are padded with 0 and too large ones are\n",
    "simply truncated. The same operations will be applied to the test set."
   ]
  },
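  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The index layout described above (0 for padding, 1 for `[UNK]`, last index for `[mask]`,\n",
    "fixed-length output) can be illustrated with a small pure-Python sketch. It is a toy\n",
    "stand-in, not the Keras `TextVectorization` layer used in this notebook; the helpers\n",
    "`build_vocab` and `encode_line` are hypothetical names introduced only for this example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy illustration only; the notebook itself relies on TextVectorization.\n",
    "def build_vocab(token_lines):\n",
    "    # Count tokens and order them from most to least frequent.\n",
    "    counts = {}\n",
    "    for line in token_lines:\n",
    "        for tok in line.split():\n",
    "            counts[tok] = counts.get(tok, 0) + 1\n",
    "    ordered = sorted(counts, key=lambda t: (-counts[t], t))\n",
    "    # Index 0 = padding, index 1 = out-of-vocabulary, last index = mask token.\n",
    "    return ['', '[UNK]'] + ordered + ['[mask]']\n",
    "\n",
    "def encode_line(line, vocab, max_len):\n",
    "    index = {tok: i for i, tok in enumerate(vocab)}\n",
    "    ids = [index.get(tok, 1) for tok in line.split()][:max_len]  # truncate\n",
    "    return ids + [0] * (max_len - len(ids))                      # pad with 0\n",
    "\n",
    "toy_vocab = build_vocab(\n",
    "    ['for ( id = number ; id < strlen ( id ) ; id operator ) { }'])\n",
    "toy_ids = encode_line('for ( id ) { }', toy_vocab, max_len=8)\n",
    "print(toy_vocab)\n",
    "print(toy_ids)"
   ]
  },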
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "xaKaJk5wr4LD",
    "outputId": "40c85f5c-eb40-4579-e25e-03487cde4e9d"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                                              tokens\n",
      "0  # include < id . id > # include < id . id > in...\n",
      "1  # include < id . id > int main ( void ) { int ...\n",
      "2  # include < id . id > # include < id . id > in...\n",
      "3  # include < id . id > int id ( int id , int id...\n",
      "4  long long id [ number ] [ number ] , id , id ;...\n"
     ]
    }
   ],
   "source": [
    "# Read all files and return content as list of lines.\n",
    "def get_text_list_from_files(files):\n",
    "    text_list = []\n",
    "    for name in files:\n",
    "        with open(name) as f:\n",
    "            for line in f:\n",
    "                text_list.append(line)\n",
    "    return text_list\n",
    "\n",
    "# Compose the full path names to the token files.\n",
    "# Creates and returns a dataframe with single key \"tokens\".\n",
    "def get_data_from_text_files(folder_name):\n",
    "    files = glob.glob(folder_name + '/*.toks')\n",
    "    texts = get_text_list_from_files(files)\n",
    "    df = pd.DataFrame({'tokens': texts})\n",
    "    df = df.sample(len(df)).reset_index(drop=True)\n",
    "    return df\n",
    "\n",
    "train_data = get_data_from_text_files('tokens/train')\n",
    "print(train_data.head())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "rFNaRXWmsOwp"
   },
   "source": [
    "Let's collect all configuration parameters in a one place:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "id": "jl8CauE9sWLG"
   },
   "outputs": [],
   "source": [
    "@dataclass\n",
    "class Config:\n",
    "    MAX_LEN = 256               # length of each input sample in tokens\n",
    "    BATCH_SIZE = 32             # batch size\n",
    "    LR = 0.001                  # learning rate\n",
    "    VOCAB_SIZE = 256            # max. number of words in vocabulary\n",
    "    EMBED_DIM = 128             # word embedding vector size\n",
    "    NUM_HEAD = 8                # number of attention heads (BERT)\n",
    "    FF_DIM = 128                # feedforward dimension (BERT)\n",
    "    NUM_LAYERS = 1              # number of BERT module layers\n",
    "\n",
    "config = Config()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "SdOiGnv7shXC"
   },
   "source": [
    "We use the Keras TextVectorization layer to process all data and extract a vocabulary of tokens to which we add the special `[mask]` token:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "wGjbPQq8swbW",
    "outputId": "98adfc09-673b-4de0-bc08-70420c115100"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "vocabulary size: 203\n",
      "padding token vocab[0]: \"\"\n",
      "OOV token vocab[1]: \"[UNK]\"\n",
      "mask token vocab[202]: \"[mask]\"\n"
     ]
    }
   ],
   "source": [
    "# No special text filtering.\n",
    "def custom_standardization(input_data):\n",
    "    return input_data\n",
    "\n",
    "# Create TextVectorization layer.\n",
    "def get_vectorize_layer(texts, vocab_size, max_seq):\n",
    "    vectorize_layer = TextVectorization(\n",
    "        max_tokens=vocab_size,\n",
    "        output_mode='int',\n",
    "        standardize=custom_standardization,\n",
    "        output_sequence_length=max_seq,\n",
    "    )\n",
    "    # Create vocabulary over all texts:\n",
    "    vectorize_layer.adapt(texts)\n",
    "    # Insert special mask token in vocabulary:\n",
    "    vocab = vectorize_layer.get_vocabulary()\n",
    "    vocab = vocab[2:len(vocab)-1] + ['[mask]']\n",
    "    vectorize_layer.set_vocabulary(vocab)\n",
    "    return vectorize_layer\n",
    "\n",
    "vectorize_layer = get_vectorize_layer(\n",
    "    train_data.tokens.values.tolist(),\n",
    "    config.VOCAB_SIZE,\n",
    "    config.MAX_LEN,\n",
    ")\n",
    "\n",
    "vocab = vectorize_layer.get_vocabulary()\n",
    "print('vocabulary size:', len(vocab))\n",
    "print('padding token vocab[0]: \"%s\"' % vocab[0])\n",
    "print('OOV token vocab[1]: \"%s\"' % vocab[1])\n",
    "print('mask token vocab[%d]: \"%s\"' % (len(vocab)-1, vocab[len(vocab)-1]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "ehiNdjVktMV9",
    "outputId": "dfadd95e-e056-4fd7-cd41-5b2e856335e3"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "mask_token_id: 202\n",
      "x_all_tokens.shape: (50000, 256)\n"
     ]
    }
   ],
   "source": [
    "# Encode the token strings to int vocab indices.\n",
    "def encode(texts):\n",
    "    encoded_texts = vectorize_layer(texts)\n",
    "    return encoded_texts.numpy()\n",
    "\n",
    "# Get mask token id for masked language model\n",
    "mask_token_id = encode(['[mask]'])[0][0]\n",
    "print('mask_token_id:', mask_token_id) # (always last index in vocab)\n",
    "\n",
    "# Randomly replace tokens by the [mask] and keep replaced token as label.\n",
    "def get_masked_input_and_labels(encoded_texts):\n",
    "    # These numbers come from something called the \"BERT recipe\":\n",
    "    # 15% used for prediction. 80% of that is masked. 10% is random token,\n",
    "    # 10% is just left as is.\n",
    "\n",
    "    # 15% masking:\n",
    "    inp_mask = np.random.rand(*encoded_texts.shape) < 0.15\n",
    "    # Do not mask special tokens:\n",
    "    inp_mask[encoded_texts < 2] = False\n",
    "    # Set targets to -1 by default, it means ignore:\n",
    "    labels = -1 * np.ones(encoded_texts.shape, dtype=int)\n",
    "    # Set golden labels for the masked tokens:\n",
    "    labels[inp_mask] = encoded_texts[inp_mask]\n",
    "    # False positions -> -1, True -> encoded word (vocab index)\n",
    "\n",
    "    # Prepare input\n",
    "    encoded_texts_masked = np.copy(encoded_texts)\n",
    "    # Set input to [mask] for 90% of tokens (leaving 10% unchanged):\n",
    "    inp_mask_2mask = inp_mask & (np.random.rand(*encoded_texts.shape) < 0.90)\n",
    "    encoded_texts_masked[inp_mask_2mask] = mask_token_id\n",
    "\n",
    "    # Set 10% to a random token\n",
    "    inp_mask_2random = inp_mask_2mask & (np.random.rand(*encoded_texts.shape) < 1 / 9)\n",
    "    encoded_texts_masked[inp_mask_2random] = np.random.randint(\n",
    "        2, mask_token_id, inp_mask_2random.sum())\n",
    "\n",
    "    # Prepare sample_weights to pass to .fit() method:\n",
    "    sample_weights = np.ones(encoded_texts.shape)\n",
    "    sample_weights[labels == -1] = 0\n",
    "\n",
    "    # y_labels would be same as encoded_texts, i.e., input tokens\n",
    "    y_labels = np.copy(encoded_texts)\n",
    "\n",
    "    return encoded_texts_masked, y_labels, sample_weights\n",
    "\n",
    "# Prepare data for masked language model\n",
    "\n",
    "# Encoding step:\n",
    "x_all_tokens = encode(train_data.tokens.values)\n",
    "print('x_all_tokens.shape:', x_all_tokens.shape)\n",
    "\n",
    "# Masking step:\n",
    "x_masked_train, y_masked_labels, sample_weights = get_masked_input_and_labels(\n",
    "    x_all_tokens\n",
    ")\n",
    "\n",
    "mlm_ds = (\n",
    "    tf.data.Dataset.from_tensor_slices(\n",
    "        (x_masked_train, y_masked_labels, sample_weights))\n",
    "    .shuffle(1000)\n",
    "    .batch(config.BATCH_SIZE)\n",
    ")"
   ]
  },
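  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The thresholds 0.15, 0.90 and 1/9 in the cell above combine into the 80/10/10 proportions\n",
    "of the BERT recipe: 0.90 times (1 - 1/9) gives 0.80 for `[mask]`, 0.90 times 1/9 gives 0.10\n",
    "for a random token, and the remaining 0.10 is left unchanged. The following sketch checks\n",
    "this empirically on random toy positions; the seed and the sample count are arbitrary\n",
    "choices made for this illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(0)  # fixed seed, chosen only for repeatability\n",
    "n = 1_000_000                   # number of toy token positions\n",
    "\n",
    "selected = rng.random(n) < 0.15                # positions used for prediction\n",
    "to_mask = selected & (rng.random(n) < 0.90)    # 90% of those first get [mask]\n",
    "to_random = to_mask & (rng.random(n) < 1 / 9)  # 1/9 of those become random tokens\n",
    "\n",
    "# Fractions relative to the selected positions.\n",
    "frac_selected = selected.mean()\n",
    "frac_mask = (to_mask & ~to_random).sum() / selected.sum()\n",
    "frac_random = to_random.sum() / selected.sum()\n",
    "frac_unchanged = (selected & ~to_mask).sum() / selected.sum()\n",
    "print(frac_selected, frac_mask, frac_random, frac_unchanged)"
   ]
  },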
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "4"
   },
   "source": [
    "## Model\n",
    "\n",
    "As mentioned above, the BERT-like model is copied from the Keras\n",
    "example \"End-to-end Masked Language Modeling with BERT\" by Ankur Singh, implemented in this [Jupyter Notebook](https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/nlp/ipynb/masked_language_modeling.ipynb)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "xDJIPbEptaYW",
    "outputId": "f7a544b4-25bb-4b84-cd95-b92ee5f2a15e"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model: \"masked_bert_model\"\n",
      "__________________________________________________________________________________________________\n",
      "Layer (type)                    Output Shape         Param #     Connected to                     \n",
      "==================================================================================================\n",
      "input_1 (InputLayer)            [(None, 256)]        0                                            \n",
      "__________________________________________________________________________________________________\n",
      "word_embedding (Embedding)      (None, 256, 128)     32768       input_1[0][0]                    \n",
      "__________________________________________________________________________________________________\n",
      "tf.__operators__.add (TFOpLambd (None, 256, 128)     0           word_embedding[0][0]             \n",
      "__________________________________________________________________________________________________\n",
      "encoder_0/multiheadattention (M (None, 256, 128)     66048       tf.__operators__.add[0][0]       \n",
      "                                                                 tf.__operators__.add[0][0]       \n",
      "                                                                 tf.__operators__.add[0][0]       \n",
      "__________________________________________________________________________________________________\n",
      "encoder_0/att_dropout (Dropout) (None, 256, 128)     0           encoder_0/multiheadattention[0][0\n",
      "__________________________________________________________________________________________________\n",
      "tf.__operators__.add_1 (TFOpLam (None, 256, 128)     0           tf.__operators__.add[0][0]       \n",
      "                                                                 encoder_0/att_dropout[0][0]      \n",
      "__________________________________________________________________________________________________\n",
      "encoder_0/att_layernormalizatio (None, 256, 128)     256         tf.__operators__.add_1[0][0]     \n",
      "__________________________________________________________________________________________________\n",
      "encoder_0/ffn (Sequential)      (None, 256, 128)     33024       encoder_0/att_layernormalization[\n",
      "__________________________________________________________________________________________________\n",
      "encoder_0/ffn_dropout (Dropout) (None, 256, 128)     0           encoder_0/ffn[0][0]              \n",
      "__________________________________________________________________________________________________\n",
      "tf.__operators__.add_2 (TFOpLam (None, 256, 128)     0           encoder_0/att_layernormalization[\n",
      "                                                                 encoder_0/ffn_dropout[0][0]      \n",
      "__________________________________________________________________________________________________\n",
      "encoder_0/ffn_layernormalizatio (None, 256, 128)     256         tf.__operators__.add_2[0][0]     \n",
      "__________________________________________________________________________________________________\n",
      "mlm_cls (Dense)                 (None, 256, 256)     33024       encoder_0/ffn_layernormalization[\n",
      "==================================================================================================\n",
      "Total params: 165,376\n",
      "Trainable params: 165,376\n",
      "Non-trainable params: 0\n",
      "__________________________________________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "def bert_module(query, key, value, i):\n",
    "    # Multi headed self-attention\n",
    "    att_out = MultiHeadAttention(\n",
    "        num_heads=config.NUM_HEAD,\n",
    "        key_dim=config.EMBED_DIM // config.NUM_HEAD,\n",
    "        name='encoder_{}/multiheadattention'.format(i))(query, key, value)\n",
    "    att_out = Dropout(0.1, name='encoder_{}/att_dropout'.format(i))(att_out)\n",
    "    att_out = LayerNormalization(\n",
    "        epsilon=1e-6,\n",
    "        name='encoder_{}/att_layernormalization'.format(i))(query + att_out)\n",
    "\n",
    "    # Feed-forward layer\n",
    "    ffn = Sequential([\n",
    "            Dense(config.FF_DIM, activation='relu'),\n",
    "            Dense(config.EMBED_DIM)\n",
    "            ], name='encoder_{}/ffn'.format(i))\n",
    "    ffn_out = ffn(att_out)\n",
    "    ffn_out = Dropout(0.1, name='encoder_{}/ffn_dropout'.format(i))(ffn_out)\n",
    "    sequence_output = LayerNormalization(\n",
    "        epsilon=1e-6,\n",
    "        name='encoder_{}/ffn_layernormalization'.format(i))(att_out + ffn_out)\n",
    "    return sequence_output\n",
    "\n",
    "def get_pos_encoding_matrix(max_len, d_emb):\n",
    "    pos_enc = np.array(\n",
    "        [\n",
    "            [pos / np.power(10000, 2 * (j // 2) / d_emb) for j in range(d_emb)]\n",
    "            if pos != 0\n",
    "            else np.zeros(d_emb)\n",
    "            for pos in range(max_len)\n",
    "        ]\n",
    "    )\n",
    "    # 0::2 means start at 0 and step 2 (all even)\n",
    "    pos_enc[1:, 0::2] = np.sin(pos_enc[1:, 0::2])  # dim 2i\n",
    "    pos_enc[1:, 1::2] = np.cos(pos_enc[1:, 1::2])  # dim 2i+1\n",
    "    return pos_enc\n",
    "\n",
    "loss_fn = losses.SparseCategoricalCrossentropy(\n",
    "    reduction=losses.Reduction.NONE\n",
    "    )\n",
    "loss_tracker = metrics.Mean(name='loss')\n",
    "\n",
    "class MaskedLanguageModel(Model):\n",
    "    def train_step(self, inputs):\n",
    "        if len(inputs) == 3:\n",
    "            features, labels, sample_weight = inputs\n",
    "        else:\n",
    "            features, labels = inputs\n",
    "            sample_weight = None\n",
    "\n",
    "        with tf.GradientTape() as tape:\n",
    "            predictions = self(features, training=True)\n",
    "            loss = loss_fn(labels, predictions, sample_weight=sample_weight)\n",
    "\n",
    "        # Compute gradients:\n",
    "        trainable_vars = self.trainable_variables\n",
    "        gradients = tape.gradient(loss, trainable_vars)\n",
    "\n",
    "        # Update weights:\n",
    "        self.optimizer.apply_gradients(zip(gradients, trainable_vars))\n",
    "\n",
    "        # Compute our own metrics:\n",
    "        loss_tracker.update_state(loss, sample_weight=sample_weight)\n",
    "\n",
    "        # Return a dict mapping metric names to current value\n",
    "        return {'loss': loss_tracker.result()}\n",
    "\n",
    "    @property\n",
    "    def metrics(self):\n",
    "        # We list our `Metric` objects here so that `reset_states()` can be\n",
    "        # called automatically at the start of each epoch\n",
    "        # or at the start of `evaluate()`.\n",
    "        # If you don't implement this property, you have to call\n",
    "        # `reset_states()` yourself at the time of your choosing.\n",
    "        return [loss_tracker]\n",
    "\n",
    "def create_masked_language_bert_model():\n",
    "    inputs = Input((config.MAX_LEN,), dtype=tf.int64)\n",
    "\n",
    "    word_embeddings = Embedding(\n",
    "        input_dim=config.VOCAB_SIZE,\n",
    "        output_dim=config.EMBED_DIM,\n",
    "        name='word_embedding')(inputs)\n",
    "\n",
    "    position_embeddings = Embedding(\n",
    "        input_dim=config.MAX_LEN,\n",
    "        output_dim=config.EMBED_DIM,\n",
    "        weights=[get_pos_encoding_matrix(config.MAX_LEN, config.EMBED_DIM)],\n",
    "        name='position_embedding',\n",
    "        )(tf.range(start=0, limit=config.MAX_LEN, delta=1))\n",
    "\n",
    "    encoder_out = word_embeddings + position_embeddings\n",
    "\n",
    "    for i in range(config.NUM_LAYERS):\n",
    "        encoder_out = bert_module(encoder_out, encoder_out, encoder_out, i)\n",
    "\n",
    "    mlm_output = Dense(config.VOCAB_SIZE, name='mlm_cls',\n",
    "                       activation='softmax')(encoder_out)\n",
    "    mlm_model = MaskedLanguageModel(inputs, mlm_output,\n",
    "                                    name='masked_bert_model')\n",
    "    optimizer = optimizers.Adam(learning_rate=config.LR)\n",
    "    mlm_model.compile(optimizer=optimizer)\n",
    "    return mlm_model\n",
    "\n",
    "mlm_model = create_masked_language_bert_model()\n",
    "mlm_model.summary()"
   ]
  },
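  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The parameter counts in the summary above can be checked by hand. The sketch below\n",
    "assumes the usual count for a dense layer (inputs times outputs, plus one bias per output)\n",
    "and four such projections (query, key, value, output) in the attention layer. The summary\n",
    "lists no separate position embedding layer, so none is included in the tally. The helper\n",
    "`dense_params` is a hypothetical name used only here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def dense_params(n_in, n_out):\n",
    "    # Weight matrix plus one bias per output unit.\n",
    "    return n_in * n_out + n_out\n",
    "\n",
    "vocab_size, embed_dim, ff_dim = 256, 128, 128  # values from Config above\n",
    "\n",
    "word_embedding = vocab_size * embed_dim\n",
    "# With key_dim = embed_dim // num_heads the heads jointly span embed_dim,\n",
    "# so each of the four projections maps embed_dim to embed_dim.\n",
    "attention = 4 * dense_params(embed_dim, embed_dim)\n",
    "layer_norm = 2 * embed_dim  # one scale and one offset per feature\n",
    "ffn = dense_params(embed_dim, ff_dim) + dense_params(ff_dim, embed_dim)\n",
    "mlm_cls = dense_params(embed_dim, vocab_size)\n",
    "\n",
    "total = word_embedding + attention + layer_norm + ffn + layer_norm + mlm_cls\n",
    "print(total)  # 165376, matching 'Total params' in the summary"
   ]
  },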
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "5"
   },
   "source": [
    "## Training\n",
    "\n",
    "The model is trained with 50,000 samples in batches of 32 (1563\n",
    "batches per epoch) over 5 epochs with a learning rate of 0.001 using the\n",
    "Adam optimizer."
   ]
  },
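  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The batch count quoted above follows from rounding up, since the last, smaller batch is\n",
    "kept by default when batching a `tf.data` dataset. A quick check of the arithmetic:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "num_samples, batch_size, epochs = 50_000, 32, 5\n",
    "\n",
    "steps_per_epoch = math.ceil(num_samples / batch_size)            # 1563\n",
    "last_batch = num_samples - (steps_per_epoch - 1) * batch_size    # 16 samples\n",
    "total_steps = steps_per_epoch * epochs                           # 7815\n",
    "print(steps_per_epoch, last_batch, total_steps)"
   ]
  },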
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 528
    },
    "id": "Giycf1_xtnX8",
    "outputId": "829e91cd-2275-4e9c-9ecf-a478b3117378"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train BERT MLM model on Project CodeNet:\n",
      "Epoch 1/5\n",
      "1563/1563 [==============================] - 465s 297ms/step - loss: 1.9673\n",
      "Epoch 2/5\n",
      "1563/1563 [==============================] - 497s 318ms/step - loss: 0.7906\n",
      "Epoch 3/5\n",
      "1563/1563 [==============================] - 493s 316ms/step - loss: 0.5638\n",
      "Epoch 4/5\n",
      "1563/1563 [==============================] - 495s 317ms/step - loss: 0.4734\n",
      "Epoch 5/5\n",
      "1563/1563 [==============================] - 463s 296ms/step - loss: 0.4187\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAVUAAAFACAYAAAAMF+8GAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjQuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Z1A+gAAAACXBIWXMAAAsTAAALEwEAmpwYAAAxV0lEQVR4nO3de1hUBf4/8PcwAwx3GUZAQFyvCd5QUdRVF3LEtrzQTzQ3s4u5RVhGWmuSmmYmq2KsKV8tTbfLlpmadldSUxetDCHFTaUsNVDioiD3mTm/P4iREXS4nJk5M/N+PU/P03DOzLwd492Zc/vIBEEQQEREonCydgAiInvCUiUiEhFLlYhIRCxVIiIRsVSJiETEUiUiEhFLlazq4MGDkMlkuHTpUqueJ5PJ8M4775gplfXei2yfwtoByDbIZLLbLu/SpQt++eWXVr/uiBEjUFBQAH9//1Y9r6CgAB06dGj1+xGZG0uVWqSgoMDw75mZmZg8eTKysrLQqVMnAIBcLjdav7a2Fi4uLiZf18XFBYGBga3O05bnEFkCv/5TiwQGBhr+UalUAICOHTsafubv74+1a9fi/vvvh4+PD2bMmAEAeOGFFxAWFgZ3d3d07twZCQkJuHbtmuF1b/763/B43759GD16NNzd3REeHo7PP//cKM/NX8llMhnS09MxY8YMeHl5ISQkBCtWrDB6TnFxMaZMmQIPDw8EBARg0aJFeOihh6DRaFr1WRQUFGDatGno0KED3NzcEB0djePHjxuW19XVYe7cuQgJCYGrqys6deqEadOmGZbn5uZi3Lhx6NChAzw8PBAWFoa33367VRlIuliqJJqlS5dixIgRyMrKwssvvwwAcHNzw+uvv47Tp09j69atOHjwIObMmWPytZ599lkkJycjJycHUVFRuO+++1BaWmry/UePHo3s7GwsWLAAycnJ+OqrrwzLH3nkEeTk5OCTTz7B/v37cenSJXz00Uet+jMKgoC4uDj8+OOP+OSTT/Dtt98iICAAY8eORVFREQDgtddewwcffIB33nkH586dw549ezBs2DDDa/ztb3+Dn58fMjMzcfLkSaxZswa+vr6tykESJhC10oEDBwQAwsWLFw0/AyDMnDnT5HN37twpuLi4CDqdrtnXani8Y8cOw3MuX74sABC++OILo/d7++23jR4/9dRTRu/Vu3dv4fnnnxcEQRDOnj0rABAyMjIMy2tra4WQkBBhzJgxt83c+L0yMjIEAEJubq5heXV1tRAYGCgsXbpUEARBmDNnjhATEyPo9fpmX8/b21vYsmXLbd+TbBe3VEk0Q4cObfKznTt3YvTo0QgKCoKnpyemT5+O2tpaXL58+bavFRERYfj3gIAAyOVyXLlypcXPAYCgoCDDc06fPg0ARluMzs7OiIyMvO1r3iw3Nxd+fn4IDw83/MzV1RVRUVHIzc0FUL9FfPLkSfTo0QMJCQnYsWMHamtrDes/++yzmDVrFqKjo7FkyRJkZWW1KgNJG0uVROPh4WH0+JtvvsGUKVMwevRo7Nq1C1lZWdiwYQMAGJVMc5o7yKXX61v1HJlM1uQ5ps5iEENERATOnz+P1atXw8XFBU8//TQiIiJQVlYGAFi0aBHOnj2LqVOn4tSpUxg2bBgWLlxo9lxkGSxVMpsjR45ArVbj5ZdfRlRUFHr16tXq81HF0rBlefToUcPPtFotvv/++1a9Tp8+fVBcXGzY8gWAmpoafPPNN+jbt6/hZ56enrj33nuxdu1aHD9+HP/73//w9ddfG5Z369YNiYmJ+PDDD/HSSy/h//7v/9r6RyOJ4SlVZDZ33HEHfv/9d2zevBkxMTE4cuQI0tPTrZKlZ8+emDBhAmbPno2NGzeiY8eOSE1NRVlZWau2Xu+8804MHToU999/P9avXw8fHx8sW7YM1dXVeOKJJwAAq1atQlBQECIiIuDu7o733nsPcrkcvXr1wvXr1zF//nxMnjwZXbt2
xdWrV/HFF18Y7U4g28YtVTKb8ePH44UXXkBycjL69euH999/H6tWrbJani1btqBv377461//iujoaAQHB2Ps2LFQKpUtfg2ZTIaPPvoIvXv3xj333IMhQ4bg8uXL2LdvH9RqNQDA29sba9aswfDhw9GvXz/s2rULO3bswB133AGFQoHS0lI8+uijCAsLw7hx4xAQEID//Oc/5vpjk4XJBIF3/ifHpNPp0Lt3b0ycOBGpqanWjkN2gl//yWEcOnQIhYWFGDhwIMrLy/Hqq6/il19+wcMPP2ztaGRHWKrkMHQ6HV5++WXk5eXB2dkZffv2xYEDB9CvXz9rRyM7wq//REQi4oEqIiIRsVSJiETEUiUiEpFVD1Tl5+e3+jlqtdpwNyBrk0oWqeQApJNFKjkAZpFyDqDtWYKCgpr9ObdUiYhEZHJLtaioCOvXr8fVq1chk8mg0Whw9913G60jCAK2bNmCEydOwNXVFYmJiejWrZvZQhMRSZXJUpXL5ZgxYwa6deuGqqoqPP/88+jfvz9CQkIM65w4cQKXL1/G2rVrce7cOWzatAmvvPKKWYMTEUmRyVL19fU13JXczc0NwcHBKCkpMSrV48ePY/To0ZDJZOjVqxcqKipQWlrKu5kTtYMgCKiuroZerxftloVXrlxBTU2NKK9lDzmA22cRBAFOTk5QKpUt/jto1YGqwsJCnD9/Hj169DD6eUlJieFmEgDg5+eHkpISlipRO1RXV8PZ2RkKhXjHkxUKRZMhjdYglRyA6SxarRbV1dVwc3Nr2eu19I2rq6uRmpqKhx9+GO7u7i19mpGMjAxkZGQAAFJSUoyKuKUUCkWbnmcOUskilRyAdLJIJQfQ9ixXrlyBq6urWfJIgVRyALfPolAoIJPJWvx32KI/lVarRWpqKkaNGoWoqKgmy1UqldEpCcXFxYaJm41pNBqjyZVtOY3BHk7FsNccgHSySCUH0PYsNTU1om/NKRQKaLVaUV/TlnMALctSU1PT5O+wzadUCYKADRs2IDg4GOPHj292ncjISBw6dAiCIODs2bNwd3fnV38iG1dSUoKxY8di7NixiIiIwODBgw2PTY3DycnJwaJFi0y+x8SJE0XJmpmZiQcffFCU12ovk1uqZ86cwaFDhxAaGornnnsOQP2I3YbWjo2NxcCBA5GVlYU5c+bAxcUFiYmJ5k1NRGanUqmwb98+AEBqaio8PDyQkJBgWK7Vam/5tXnAgAEYMGCAyffYs2ePOGElxGSp9u7dGx988MFt15HJZJg1a5ZooW7l2DEXaLUyjBxp9rciomYkJSXB1dUVubm5iIyMxKRJk7B48WLU1NRAqVRizZo16NGjBzIzM7Fhwwa89dZbSE1NxW+//YYLFy7gt99+w6xZs/D4448DqB9zc+7cOWRmZmLNmjXw9fXFmTNn0L9/f7z22muQyWT46quvsHTpUri7u2PIkCH49ddf8dZbb90yY2lpKebNm4cLFy5AqVRi5cqVCA8Px9GjR7F48WIA9Z21c+dOVFRUIDExEWVlZdDpdFixYkWzuzhbQzp7iltg3TpP5OUpkJkJOPFaMCKrKCgowO7duyGXy1FeXo5du3ZBoVDg0KFD+Oc//4k33nijyXPy8vKwfft2VFRUYNSoUZg5c2aTU5ROnTqF/fv3IzAwEJMmTcJ3332H/v37Y/78+di5cydCQ0Nb9C04NTUVffv2xZtvvokjR47g6aefxr59+7Bhwwa88sorGDJkCCoqKuDq6op33nkH0dHReOqpp6DT6VBVVdXuz8emSjU+vgqzZyuRmemCkSNvv0+HyJ4sXuyN06ed2/06MpkMDbdQDg+vw0svlbX6NcaPH284gFZWVoakpCScP38eMpkMdXV1zT5nzJgxcHV1haurK9RqNX7//Xf4+/sbrRMREWE4+NOnTx9cvHgR7u7u6NKlC0JDQwEAcXFxeOedd26b79tvvzUU+8iRI1FaWory8nIMGTIES5cuxb333ou//vWv
huGM8+bNQ21tLcaNG2c0EbetbGp7b9y4Knh7C/jww7ad0kVE7df4lMpVq1ZhxIgR2L9/P7Zu3XrLk+gbnxoml8ubPdru4uJicp32ePLJJ7Fq1SpUV1cjLi4OeXl5GDZsGHbv3o3AwEA888wz2L59e7vfx6a2VN3cgMmT9di2TYlXXpHB3Z1DC8gxtGWLsjlin8pUXl6OwMBAADB57KUtunfvjl9//RUXL15E586dW3RgKyoqCjt37sQzzzyDzMxMqFQqeHl54ZdffkFYWBjCwsKQnZ2NvLw8KJVKdO7cGdOnT0dtbS1OnjyJKVOmtCuzTW2pAsD06XpUVjrhs89aPlaYiMzjiSeewIoVKxAbG2uW807d3NzwyiuvYPr06bjrrrvg4eEBb2/v2z5n7ty5OHnyJDQaDV555RWkpaUBADZt2oQ777wTGo0Gzs7OiImJQWZmJmJiYhAbG4s9e/aIcsDdqjOq2nI/VZVKjTvucEKXLjq8/36xGVK1nFROMJdKDkA6WaSSA2h7lsrKyjZfvXgrUjnpvjU5Kioq4OHhAUEQkJycjK5du+Kxxx6zaJbm/i7s5n6qTk7A5MlVOHLEBfn5NhefiFrp3XffxdixYxETE4Py8nLMmDHD2pFuyyZbafLkSgiCDLt28YAVkb177LHHsG/fPhw8eBDr1q1r8Y1NrMUmS7VrVx2GDKnB9u1u4IBtIpISmyxVoP6c1XPnnPHDD+0/d49Iiqx4uINu0pq/C5st1fHjq+DqKuDDD6X9VYCorZycnCRxUMnRabVaOLXiEk6bOk+1sQ4dBIwdW41du9ywaFEZGp03TGQXlEolqqurUVNTI9qd/11dXSVxx32p5ABun6Xxnf9bymZLFQCmTKnEJ5+44cABJcaNq7Z2HCJRyWQy0Q/KSOVUM6nkAMTPYrNf/wHgL3+pgVqt4y4AIpIMmy5VZ2cgLq4K+/YpUVIiztcjIqL2sOlSBep3AdTVybBnD7dWicj6bL5U+/TRIiysjneuIiJJsPlSlcmA+PhKnDjhgrw8aYy8JSLHZfOlCgD33lsFJyfeZ5WIrM8uSjUgQI+//KUGO3a4Qa+3dhoicmR2UapA/WWr+fkKZGbyKgAish67KdVx46rg5aXnLgAisiq7KVU3t/r7AXz6qRKVlTxnlYisw+Rlqunp6cjKyoKPjw9SU1ObLK+srMTatWtRXFwMnU6HCRMmICYmxixhTYmPr8J773ngs8+UiI9v/6hZIqLWMrmlGh0djeTk5Fsu/+KLLxASEoJVq1ZhyZIleOutt6x2Z52hQ2sRGqrlLgAishqTpRoeHg5PT89bLpfJZKiuroYgCKiuroanp2erbpMlJo5aISJra3fz3HXXXfjtt9/w+OOPY968eXjkkUesVqoAR60QkXW1+9Z/OTk56NKlCxYvXowrV65g2bJl6N27d7NTIDMyMpCRkQEASElJgVqtbn1gheK2z1OrgREj9Ni50wuLFysh0m0o25TFUqSSA5BOFqnkAJhFyjkA8bO0u1QPHDiAuLg4yGQyBAYGwt/fH/n5+ejRo0eTdTUaDTQajeFxW+5h2JJ7H06a5I758ztg//5rGDCgrtXvIWYWS5BKDkA6WaSSA2AWKecA2p7FbCOq1Wo1Tp48CQC4evUq8vPz4e/v396XbZcJEzhqhYisw+SWalpaGk6fPo3y8nIkJCRg6tSphqP7sbGxmDx5MtLT0zFv3jwAwPTp0+Ht7W3e1Cb4+AiIjeWoFSKyPJOlmpSUdNvlKpUKCxcuFCuPaOLjK/Hxxxy1QkSWZbfnHXHUChFZg92WKketEJE12G2pAhy1QkSWZ9elylErRGRpdl2qHLVCRJZm16UKcNQKEVmW3ZcqR60QkSXZfakCHLVCRJbjEKXKUStEZCkOUaoctUJEluIQpQrU7wKorHTCZ58prR2FiOyYw5QqR60QkSU4TKly1AoRWYJDtQtHrRCRuTlUqXbtqsOQITXYvt0NgmDtNERkjxyq
VIH6A1bnzjnjhx+crR2FiOyQw5UqR60QkTk5XKk2HrVSW2vtNERkbxyuVIH6O1eVlspx4ADPWSUicTlkqXLUChGZi0OWqrNz/S0BOWqFiMTmkKUK1O8C4KgVIhKbw5YqR60QkTkoTK2Qnp6OrKws+Pj4IDU1tdl1cnNzsXXrVuh0Onh5eWHp0qWiBxVbw6iVZct8kJcnR48eOmtHIiI7YHJLNTo6GsnJybdcXlFRgU2bNmH+/PlYs2YN5s6dK2pAc+KoFSISm8lSDQ8Ph6en5y2XHzlyBFFRUVCr1QAAHx8f8dKZGUetEJHY2r1PtaCgANevX8eSJUswf/58fP3112LkshiOWiEiMZncp2qKTqfD+fPnsWjRItTW1mLhwoXo2bMngoKCmqybkZGBjIwMAEBKSoph67ZVgRWKNj3vVqZPBxYsEPDJJ76Ii2vdflWxs7SVVHIA0skilRwAs0g5ByB+lnaXqp+fH7y8vKBUKqFUKhEWFoZff/212VLVaDTQaDSGx0VFRa1+P7Va3abn3c499/hgxw43LF5cDHf3lt++yhxZ2kIqOQDpZJFKDoBZpJwDaHuW5joOEOHrf2RkJH788UfodDrU1NQgLy8PwcHB7X1Zi+KoFSISi8kt1bS0NJw+fRrl5eVISEjA1KlTodVqAQCxsbEICQlBREQEnn32WTg5OeHOO+9EaGio2YOLqfGolfj4KmvHISIbZrJUk5KSTL7IxIkTMXHiRDHyWEXDqJW0NE/k5zshKIinAhBR2zjsFVU346gVIhIDS/UPHLVCRGJgqTbCUStE1F4s1UY4aoWI2oul2ghHrRBRe7FUb8JRK0TUHizVm3DUChG1B0v1Jhy1QkTtwVJtBketEFFbsVSbwVErRNRWLNVmNIxaOXHCBXl5cmvHISIbwlK9BY5aIaK2YKneAketEFFbsFRvY8qUSo5aIaJWYaneRmxsNby89NwFQEQtxlK9DTc3YPz4Knz6qRKVlTxnlYhMY6mawFErRNQaLFUTGo9aISIyhaVqQsOolSNHXJCfz4+LiG6PLdECHLVCRC3FUm0BjlohopZiqbYQR60QUUuwVFuIo1aIqCVMlmp6ejpmzZqFefPm3Xa9vLw8TJs2DceOHRMtnJRw1AoRtYTJUo2OjkZycvJt19Hr9Xj33XcxYMAA0YJJEUetEJEpJks1PDwcnp6et13n888/R1RUFLy9vUULJkUctUJEprR7n2pJSQm+/fZbxMbGipFH0jhqhYhMUbT3BbZu3Yrp06fDycl0P2dkZCAjIwMAkJKSArVa3er3UygUbXqeWP7+dxneeEOG/fs7IjzcyapZGlj7M2lMKlmkkgNgFinnAMTP0u5S/emnn/Cvf/0LAFBWVoYTJ07AyckJQ4cObbKuRqOBRqMxPC4qKmr1+6nV6jY9TyxBQUBYWEds3SogIUFv1SwNrP2ZNCaVLFLJATCLlHMAbc8SFBTU7M/bXarr1683+vfBgwc3W6j2omHUyrJlPjhzphZ+ftZORERSYvI7e1paGhYuXIj8/HwkJCRg//792Lt3L/bu3WuJfJLUMGrl3Xc5v4qIjJncUk1KSmrxi82ePbs9WWxGw6iV//zHFU8+WX/TFSIigFdUtdmUKZW4eFHGUStEZISl2kaxsdXw9ua0VSIyxlJtIzc3YPJkPUetEJERlmo7PPCAnqNWiMgIS7UdRowQOGqFiIywVNuBo1aI6GZsgnbiqBUiaoyl2k4ctUJEjbFURcBRK0TUgKUqAo5aIaIGLFURcNQKETVgqYqEo1aICGCpioajVogIYKmKhqNWiAhgqYoqPr4SdXUy7NnDrVUiR8VSFVGfPlqEhdXxslUiB8ZSFVHDqJUTJ1yQl8epAESOiKUqsoZRK9xaJXJMLFWRNYxa2bHDDXq9tdMQkaWxVM1gypRK5OcrOGqFyAGxVM0gNrYaXl567gIgckAsVTNwcwPGj6/iqBUiB8RSNZMpU6o4aoXIASlMrZCeno6s
rCz4+PggNTW1yfLDhw9j9+7dEAQBbm5umDVrFv70pz+ZI6tNGTKk1jBqJT6+ytpxiMhCTG6pRkdHIzk5+ZbL/f39sWTJEqSmpmLy5Ml4/fXXRQ1oqzhqhcgxmfxtDw8Ph6en5y2X33HHHYblPXv2RHFxsXjpbFx8PEetEDkaUTeh9u/fj4EDB4r5kjbtT3/iqBUiR2Nyn2pLnTp1CgcOHMBLL710y3UyMjKQkZEBAEhJSYFarW71+ygUijY9zxxakuXhh50we7YCFy50xODB5mlWW/tMHCkHwCxSzgGIn0WUUv3111+xceNGLFiwAF5eXrdcT6PRQKPRGB4XFRW1+r3UanWbnmcOLckSEyODq2sgNm2qQZcuZVbLYSlSySKVHACzSDkH0PYsQUFBzf683V//i4qKsHr1ajz55JO3fBNHxlErRI7F5JZqWloaTp8+jfLyciQkJGDq1KnQarUAgNjYWHz44Ye4fv06Nm3aBACQy+VISUkxb2obEx9fiY8/dsOBA0qMG1dt7ThEZEYmSzUpKem2yxMSEpCQkCBWHrvUeNQKS5XIvvEESgvgqBUix8FStRCOWiFyDCxVC+GoFSLHwFK1EI5aIXIMLFUL4qgVIvvHUrUgjlohsn8sVQvjqBUi+8ZStTCOWiGybyxVC+OoFSL7xlK1Ao5aIbJfLFUraDxqhYjsC0vVCjhqhch+8TfaSjhqhcg+sVSthKNWiOwTS9WK4uOrcO6cM374wdnaUYhIJCxVK5owoQqurgI+/JB3riKyFyxVK+KoFSL7w1K1svj4SpSWynHgAM9ZJbIHLFUrazxqhYhsH0vVyjhqhci+sFQlgKNWiOwHS1UCOGqFyH6wVCWAo1aI7IfJUk1PT8esWbMwb968ZpcLgoA333wTTz31FJ599ln8/PPPood0BBy1QmQfTJZqdHQ0kpOTb7n8xIkTuHz5MtauXYvHHnsMmzZtEjWgo+CoFSL7YLJUw8PD4enpecvlx48fx+jRoyGTydCrVy9UVFSgtLRU1JCOgqNWiGxfu/eplpSUQK1WGx77+fmhpKSkvS/rkDhqhcj2KSz5ZhkZGcjIyAAApKSkGJVxSykUijY9zxzMkSU+XsAHH7hh40YFPDysl6OtpJJFKjkAZpFyDkD8LO0uVZVKhaKiIsPj4uJiqFSqZtfVaDTQaDSGx42f11JqtbpNzzMHc2SZMMEFW7ao8fbbFYiPr7JajraSShap5ACYRco5gLZnCQoKavbn7f76HxkZiUOHDkEQBJw9exbu7u7w9fVt78s6LI5aIbJtJrdU09LScPr0aZSXlyMhIQFTp06FVqsFAMTGxmLgwIHIysrCnDlz4OLigsTERLOHtmcNo1bS0jyRn++EoCCeCkBkS0yWalJS0m2Xy2QyzJo1S6w8hPoLAV591Qu7drlj9uzr1o5DRK3AK6okiKNWiGwXS1WiOGqFyDaxVCWKo1aIbBNLVaI4aoXINrFUJYyjVohsD0tVwqKjOWqFyNawVCVMoeCoFSJbw1KVOI5aIbItLFWJ46gVItvCUpU4jlohsi0sVRvAUStEtoOlagM4aoXIdrBUbQRHrRDZBpaqjeCoFSLbwFK1EW5uwPjxVfj0UyUqK3nOKpFUsVRtyJQpVaisdMJnn/GyVSKpYqnaEI5aIZI+lqoNaRi1cuSIC/Lz+VdHJEX8zbQx8fGVEAQZdu3i1iqRFLFUbQxHrRBJG0vVBnHUCpF0sVRtEEetEEkXS9UGcdQKkXQpWrJSdnY2tmzZAr1ejzFjxiAuLs5oeVFREdavX4+Kigro9Xrcf//9GDRokDny0h/i4yvx8cduOHBAienTrZ2GiBqY3FLV6/XYvHkzkpOT8eqrr+K///0vLl26ZLTOjh07MHz4cKxcuRJJSUnYvHmz2QJTPY5aIZImk6Wal5eHwMBABAQEQKFQYMSIEfjuu++M1pHJZKisrAQAVFZWwtfX1zxpyaDxqJXiYmunIaIGJku1pKQEfn5+
hsd+fn4oKSkxWmfKlCk4fPgwEhISsGLFCsycOVP8pNREw6iV7du5a5xIKlq0T9WU//73v4iOjsaECRNw9uxZvPbaa0hNTYWTk/Eve0ZGBjIyMgAAKSkpUKvVrQ+sULTpeeZg7SzR0UC/fnq88YYcEyeqERRktSgG1v5MpJYDYBYp5wDEz2KyVFUqFYobfb8sLi6GSqUyWmf//v1ITk4GAPTq1Qt1dXUoLy+Hj4+P0XoajQYajcbwuKioqNWB1Wp1m55nDlLI8sQTSsyZ44uwMGc88kgFEhPLoVJZ76oAKXwmUsoBMIuUcwBtzxJ0i60Yk98bu3fvjoKCAhQWFkKr1SIzMxORkZFNQp06dQoAcOnSJdTV1cHb27vVIan1Jk2qxsmTdbjnnips2OCB4cMDsGaNJ8rLeXtAImswuaUql8sxc+ZMLF++HHq9HjExMejcuTO2bduG7t27IzIyEg8++CA2btyITz/9FACQmJgImYy/1JbSrRuwdu1VzJ59HatXeyE11RtvvumBJ5+8joceqoAbTxAgshiZIFjvCvL8/PxWP8cevjaYO0d2tjNWrvTC118rERioQ1JSOaZNq4SzBa5qlepnYk3MIt0cgBW+/pPtiYiow3/+U4Lt24sQEqLD8893QHS0P3budINOZ+10RPaNpWrHRoyoxUcfFeHf/y6Gu7uAp57yRWxsR3z5pZJ3uCIyE5aqnZPJAI2mBl9++TvS00tQUyPDzJkqTJigxuHDnMxKJDaWqoNwcqo/U+DgwUKkppbiyhUnTJumxtSpfvj+e95CkEgsLFUHo1AA06ZV4fDhQrz00jX8+KMCEyd2xCOP+OL0aVGuBSFyaCxVB6VUAo8+WoGjRwsxf34Zjh1zRWxsR8ye3QE//yy3djwim8VSdXAeHgLmzLmOo0evYPbs6/jySyWio/3xj3/44Lff+J8HUWvxt4YAAB06CFiwoByZmYV46KEKfPCBO0aNCsCSJd4oLuZ/JkQtxd8WMuLvr8eyZWU4fLgQcXFV2LzZA8OH+2PVKi+UlfEqOSJTWKrUrM6ddViz5ioOHPgdMTE1SEvzwvDhAUhP90RVFcuV6FZYqnRbPXposXFjKb78shCDBtVi+XJvjBjhj61b3Tkfi6gZLFVqkb59tXj77RLs2lWErl21eOGFDhg92h8ffMBLX4kaY6lSqwwdWosdO4rx7rvF6NBBj2ee8cWYMR3x6ae89JUIYKlSG8hk9YMHP/+8CK+/XgJBAB57TIW771Zj714Zy5UcGkuV2kwmA+65pxr79/+OV18tRUmJEyZMcEZ8vB++/Zb3FSDHxFKldpPLgalTq3DoUCHS0rT46ScF7r1XjRkzVDh1ipe+kmNhqZJoXF2BJ57QIzOzEMnJZcjKcsG4cf5ISPBFXh4vfSXHwFIl0bm7C5g9+zoyM68gKakcX33lipgYf8yd2wGXLrFcyb6xVMlsfHwEPPdcOY4eLcSjj1bgo4/cMGqUPxYt8sbvv/M/PbJP/C+bzE6t1mPJkjIcPnwFU6ZU4t//rr/0dcUKL1y9yquzyL6wVMligoP1WLnyGg4eLMS4cdVYt67+0te1az1RUcFyJfvAUiWL69ZNh/Xrr2Lv3kJERdXin/+sv/R182YP1NRYOx1R+7BUyWr69NFi69YS7N79O3r21GLxYh+MGuWP9993g1Zr7XREbdOikwizs7OxZcsW6PV6jBkzBnFxcU3WyczMxPbt2yGTydClSxc8/fTTYmclOxUZWYft24tx+LAL/vlPb8yb54v1673w3HNlGD++Gk78Xz/ZEJOlqtfrsXnzZixcuBB+fn5YsGABIiMjERISYlinoKAAH330EZYtWwZPT09cu3bNrKHJ/shkwOjRtRg1qghffqnEypVeeOIJFdatq8P8+WW4884ayLjblWyAyW2AvLw8BAYGIiAgAAqFAiNGjMB3331ntM5XX32FcePGwdPTEwDg4+NjnrRk92Qy4K67qrFv3+947bVSXL8u
w4MP+iEuTo2jR3npK0mfyVItKSmBn5+f4bGfnx9KSkqM1snPz0dBQQEWLVqEF154AdnZ2aIHJccilwP/7/9V4euvC5GSchWXLskRH6/G/ferkJPDkdokXaJcmK3X61FQUIAXX3wRJSUlePHFF7F69Wp4eHgYrZeRkYGMjAwAQEpKCtRqdesDKxRtep45SCWLVHIA5snyzDNAQoIOGzYIWLXKFXffrURcnB5LlmgRFma5HG3FLNLNAYifxWSpqlQqFBcXGx4XFxdDpVI1Wadnz55QKBTw9/dHp06dUFBQgB49ehitp9FooNFoDI+LiopaHVitVrfpeeYglSxSyQGYN8uMGUBcnAxvvOGBjRs9sXu3MyZPrsK8eeUIDTW+U7ajfCatJZUsUskBtD1LUFBQsz83+fW/e/fuKCgoQGFhIbRaLTIzMxEZGWm0ztChQ5GbmwsAKCsrQ0FBAQICAlodksgULy8Bc+dex9GjhXj88Qp88okbRo/2R3KyD65c4WkCZH0mt1TlcjlmzpyJ5cuXQ6/XIyYmBp07d8a2bdvQvXt3REZGYsCAAcjJycEzzzwDJycnPPDAA/Dy8rJEfnJQKpUeixaV4e9/v45//csL777rjm3b3PDII5VITCyHRL5ZkgOSCYL17tOen5/f6ufYw9cGe80BWC/Lr7/KkZrqhZ073eDpKeDJJwVERpaiX786eHhYdxQB/36kmwMQ/+s/7yBMdqFLFx3Wrr2KxMTrWL3aCytWuAFQw8lJQK9eWgwYUIeIiFpERNShd+86uPDsLDITlirZld69tdi0qRQ6nRz795cjJ8cFOTnO2LvXFdu2uQMAXF0FhIffKNmIiDp066bllVskCpYq2aWAAGDs2BqMHVt/hxZBAC5elCM72xnZ2fVFu22bO7ZsqW9SLy89+vevL9qGrdqgID2v4qJWY6mSQ5DJgNBQHUJDdZg4sRoAoNMB584pkJPjjBMn6ov29dc9UVdX36QdO+qMdhsMGFALlYqjYun2WKrksOTy+t0FvXtrcd99VQCA6mrg9Gln5OTUb9FmZzvjq6+8IAj1Rduli/H+2X796uDuzqKlG1iqRI0olcCgQXUYNKgOQCUAoLxchh9+uFGy33/vjD173AAATk4C7rhDiwED6ncbDBxYfyDMmVfSOiyWKpEJXl4C/vznWvz5z7WGn/3+uxOys52Rk1NftF9+qcT779dflu3qKqBPnxv7Z2NiAF9f8ECYg2CpErVBx476Zg+EnTjhbDjj4P333fHmmw0HwgLRv38dBg6sL9oBA3ggzF6xVIlE0PhA2KRJxgfC8vJUOHKkBjk5ztiwwRNabX2T+vvrDAU7cGAd+vfngTB7wFIlMpOGA2EjR+oxfnz9jdsbHwhrOONg3z5vw3O6dNEadhsMHFiHvn15IMzWsFSJLKjxgbBHHqk/EFZWVn8grGH/7HffuWD37voLFRoOhN04f5YHwqSOpUpkZd7eAkaOrMXIkU0PhDVcqPDFF0q89179gTCl0viKsAEDatGtm44HwiSCpUokQc0dCLtwwfiKsPfeu3EgzNv7xhVhDUXbqRMPhFkDS5XIBshk9TeN6dLlxoEwrfbGFWEN59DefCCs8W6D/v1reUtEC2CpEtkohQIIC9MiLEyLadOMrwhrKNn6m8m4GZ6jUgkIDlajc2fdH/9oG/27jgfFRMBSJbIjxleE1Ws4EHbqlDOuXPHEuXN6nD2rwP79SlRXG+8f8POrL9eQEB1CQ7UICblRuCEhWri53fyOdDOWKpGda3wgTK12Q1FR/TRkQag/IHbxohyXLslx8aICFy7U//vp087Yu1eJ2lrj0u3Y8Ubh3ijf+sINDtZBqbTGn1BaWKpEDkomA/z99fD312Pw4Lomy/V6oLCwvnQvXlQYyvfCBQVyclzw2Wdywx29GgQG6v7YutUabeGGhuoQFKRziJuDs1SJqFlOTkBgoB6BgXoMGdK0dHU64PJlJ1y6pPijeG+U7/HjLtiz
Rw6d7kbpymQCAgP1CA3VokcPOTp29DLaxRAUpIPCDhrJDv4IRGQNcjkQHKxHcHAtoqKaLtdqgcuX5bhwQf7HVu6N3QuHDzvh0iVP6PWyRq8noFMnndFBtIbCDQ3VITBQB7ncgn/ANmKpEpFZKBRASEj97oCbqdVq5OcXoaDAeAu34Z9Dh1xx5Yqb4T629a8nICjIuHBDQ2/sYggM1EviAgiWKhFZhYvLjXNvgdomy2tqgN9+k9+0e6G+gA8cUOLKFflNr3ejdI3PXKjfv9uxo2VKl6VKRJLk6gp066ZDt25Nt3QBoKrqRuk27FZo2OL98ksliorkN72egJCQpmctREcDPj7i5W5RqWZnZ2PLli3Q6/UYM2YM4uLiml3v2LFjWLNmDVasWIHu3buLl5KI6CZubkCPHjr06NF86VZWyv4oWnmTMxhycpxRWlpfui+9pMWjj4qXy2Sp6vV6bN68GQsXLoSfnx8WLFiAyMhIhISEGK1XVVWFzz//HD179hQvHRFRG7m7C+jVS4tevbTNLr9+vb50u3btIOr7mtzDkJeXh8DAQAQEBEChUGDEiBH47rvvmqy3bds2TJo0Cc68JxkR2QBPTwG9e2sRHCzu65rcUi0pKYGfn5/hsZ+fH86dO2e0zs8//4yioiIMGjQIe/bsueVrZWRkICMjAwCQkpICdRvu7qBQKNr0PHOQShap5ACkk0UqOQBmkXIOQPws7T5Qpdfr8dZbbyExMdHkuhqNBhqNxvC4qKio1e+nVqvb9DxzkEoWqeQApJNFKjkAZpFyDqDtWYKCgpr9uclSValUKC4uNjwuLi6GSqUyPK6ursbFixexdOlSAMDVq1excuVK/OMf/+DBKiJyOCZLtXv37igoKEBhYSFUKhUyMzMxZ84cw3J3d3ds3rzZ8HjJkiWYMWMGC5WIHJLJUpXL5Zg5cyaWL18OvV6PmJgYdO7cGdu2bUP37t0RGRlpiZxERDahRftUBw0ahEGDBhn97L777mt23SVLlrQ7FBGRrZLAlbJERPaDpUpEJCKWKhGRiFiqREQikgmCwPGJREQisbkt1eeff97aEQykkkUqOQDpZJFKDoBZmiOVHID4WWyuVImIpIylSkQkIpsr1cY3ZLE2qWSRSg5AOlmkkgNgluZIJQcgfhYeqCIiEpHNbakSEUmZJAf/paenIysrCz4+PkhNTW2yXBAEbNmyBSdOnICrqysSExPRrVs3q2TJzc3FypUr4e/vDwCIiopCfHy86DmKioqwfv16XL16FTKZDBqNBnfffbfROpb6XFqSxRKfS21tLV588UVotVrodDoMGzYMU6dONVqnrq4O69atw88//wwvLy8kJSUZMlk6y8GDB/H2228bbp151113YcyYMaJnAervc/z8889DpVI1Obptqc+kJVks+ZnMnj0bSqUSTk5OkMvlSElJMVou2u+PIEG5ubnCTz/9JMydO7fZ5d9//72wfPlyQa/XC2fOnBEWLFhgtSynTp0SVqxYYbb3b1BSUiL89NNPgiAIQmVlpTBnzhzh4sWLRutY6nNpSRZLfC56vV6oqqoSBEEQ6urqhAULFghnzpwxWueLL74QNm7cKAiCIBw5ckRYs2aN1bIcOHBA2LRpk1ne/2Yff/yxkJaW1uzfgaU+k5ZkseRnkpiYKFy7du2Wy8X6/ZHk1//w8HB4enrecvnx48cxevRoyGQy9OrVCxUVFSgtLbVKFkvx9fU1/F/Tzc0NwcHBKCkpMVrHUp9LS7JYgkwmg1KpBADodDrodDrIZDKjdY4fP47o6GgAwLBhw3Dq1CkIZjiM0JIsllJcXIysrKxbbvFZ6jNpSRYpEev3R5Jf/00pKSkxminj5+eHkpIS+Pr6WiXP2bNn8dxzz8HX1xczZsxA586dzfp+hYWFOH/+PHr06GH0c2t8LrfKAljmc9Hr9Zg/fz4uX76McePGNZnm23jGmlwuh7u7O8rLy+Ht7W3xLADwzTff4H//+x86deqE
hx56yCxzmrZu3YoHHngAVVVVzS635GdiKgtgmc+kwfLlywEAY8eObXLUX6zfH5ssVSnp2rUr0tPToVQqkZWVhVWrVmHt2rVme7/q6mqkpqbi4Ycfhru7u9nep71ZLPW5ODk5YdWqVaioqMDq1atx4cIFhIaGiv4+YmQZPHgw/vznP8PZ2Rn79u3D+vXr8eKLL4qa4fvvv4ePjw+6deuG3NxcUV/bHFks8Zk0WLZsGVQqFa5du4aXX34ZQUFBCA8PF/19JPn13xSVSmU0qOvmuVmW5O7ubvjaN2jQIOh0OpSVlZnlvbRaLVJTUzFq1ChERUU1WW7Jz8VUFkt+LgDg4eGBPn36IDs72+jnjWes6XQ6VFZWwsvLy2w5bpfFy8vLMMJ9zJgx+Pnnn0V/7zNnzuD48eOYPXs20tLScOrUqSb/M7PUZ9KSLJb4TBo0/C74+PhgyJAhyMvLa7JcjN8fmyzVyMhIHDp0CIIg4OzZs3B3d7faV/+rV68a9kfl5eVBr9eb5T9QQRCwYcMGBAcHY/z48c2uY6nPpSVZLPG5lJWVoaKiAkD90fcffvgBwTcNcR88eDAOHjwIADh27Bj69Oljln2dLcnSeP/c8ePHERISInqO+++/Hxs2bMD69euRlJSEvn37Gs2UAyz3mbQkiyU+E6D+W1XDLojq6mr88MMPTb7RiPX7I8mv/2lpaTh9+jTKy8uRkJCAqVOnQqvVAgBiY2MxcOBAZGVlYc6cOXBxcWnReGxzZTl27Bj27t0LuVwOFxcXJCUlmeU/0DNnzuDQoUMIDQ3Fc889BwD429/+Zvg/qyU/l5ZkscTnUlpaivXr10Ov10MQBAwfPhyDBw82mp925513Yt26dXjqqafg6emJpKQkUTO0Jsvnn3+O48ePQy6Xw9PT06z/3d7MGp9JS7JY6jO5du0aVq9eDaB+63zkyJGIiIjA3r17AYj7+8MrqoiIRGSTX/+JiKSKpUpEJCKWKhGRiFiqREQiYqkSEYmIpUpEJCKWKhGRiFiqREQi+v99/JRwtQwdugAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 864x360 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "print('Train BERT MLM model on Project CodeNet:')\n",
    "history = mlm_model.fit(mlm_ds, epochs=5)\n",
    "#mlm_model.save('bert_mlm_codenet.h5')\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "plt.style.use('ggplot')\n",
    "\n",
    "def plot_history(history):\n",
    "    loss = history.history['loss']\n",
    "    x = range(1, len(loss) + 1)\n",
    "\n",
    "    plt.figure(figsize=(12, 5))\n",
    "    plt.subplot(1, 2, 1)\n",
    "    plt.plot(x, loss, 'b', label='Training loss')\n",
    "    plt.title('Training loss')\n",
    "    plt.legend()\n",
    "    plt.show()\n",
    "\n",
    "plot_history(history)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "6"
   },
   "source": [
    "## Evaluation\n",
    "\n",
    "We evaluate the trained model on a test set of 5,000 samples taken\n",
    "from problems not considered for the training set. Each sample is\n",
    "preprocessed in the same way as the training samples and one token\n",
    "(never a padding!) is arbitrarily replaced by the `[mask]`. Then a\n",
    "prediction is generated and the top 1 and top 5 results are compared\n",
    "with the expected value. The achieved accuracies are printed in the end.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "5Rbd460Tt4fI",
    "outputId": "af5f397a-ab5b-470b-db97-0d33ef541c23"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                                              tokens\n",
      "0  # include < id . id > int main ( void ) { int ...\n",
      "1  # include < id . id > int main ( ) { int id , ...\n",
      "2  # include < id . id > int main ( ) { char id [...\n",
      "3  # include < id . id > int main ( void ) { int ...\n",
      "4  # include < id . id > int main ( void ) { char...\n",
      "number of test samples: 5000\n",
      "top-1 accuracy: 0.9248\n",
      "top-5 accuracy: 0.9958\n"
     ]
    }
   ],
   "source": [
    "# Load pretrained bert model\n",
    "#from tensorflow import keras\n",
    "#mlm_model = keras.models.load_model('bert_mlm_codenet.h5',\n",
    "#    custom_objects={'MaskedLanguageModel': MaskedLanguageModel})\n",
    "mlm_model.trainable = False\n",
    "\n",
    "# token<->id mappings as dicts:\n",
    "id2token = dict(enumerate(vocab))\n",
    "token2id = {y: x for x, y in id2token.items()}\n",
    "\n",
    "# Turns text into list of vocabulary indices.\n",
    "def prep(text):\n",
    "    R = [0] * config.MAX_LEN # all padding\n",
    "    text = text.split()\n",
    "    ntoks = len(text)\n",
    "    if ntoks > config.MAX_LEN:\n",
    "        ntoks = config.MAX_LEN\n",
    "        text = text[:ntoks]\n",
    "    # pick random position (never a padding):\n",
    "    k = random.randint(0, ntoks-1)\n",
    "    golden = 0\n",
    "    for i in range(len(text)):\n",
    "        w = text[i]\n",
    "        if w in token2id:\n",
    "            R[i] = token2id[w]\n",
    "        else:\n",
    "            R[i] = 1 # OOV: [UNK]\n",
    "        if i == k:\n",
    "            golden = R[i]\n",
    "            R[i] = mask_token_id\n",
    "    return k, golden, np.array(R)\n",
    "\n",
    "def predict(text):\n",
    "    mask_index, golden, R = prep(text)\n",
    "    sample = np.reshape(R, (1, config.MAX_LEN))\n",
    "    prediction = mlm_model.predict(sample)\n",
    "    # all substitute word probabilities:\n",
    "    mask_prediction = prediction[0][mask_index]\n",
    "    # word indices with top-k highest probabilities:\n",
    "    top_k = 5\n",
    "    top_indices = mask_prediction.argsort()[-top_k:][::-1]\n",
    "    # probabilities of the top_k\n",
    "    values = mask_prediction[top_indices]\n",
    "    correct_top1 = top_indices[0] == golden\n",
    "    correct_top5 = False\n",
    "    for i in range(len(top_indices)):\n",
    "        if top_indices[i] == golden:\n",
    "            correct_top5 = True\n",
    "            break\n",
    "    return correct_top1, correct_top5\n",
    "\n",
    "# Enumerate all test samples:\n",
    "test_data = get_data_from_text_files('tokens/test')\n",
    "print(test_data.head())\n",
    "\n",
    "correct_top1 = 0\n",
    "correct_top5 = 0\n",
    "num_tests = 0\n",
    "for test in test_data['tokens']:\n",
    "    # predict and check\n",
    "    top1, top5 = predict(test)\n",
    "    if top1:\n",
    "        correct_top1+=1\n",
    "    if top5:\n",
    "        correct_top5+=1\n",
    "    num_tests+=1\n",
    "\n",
    "print('number of test samples:', num_tests)\n",
    "print('top-1 accuracy:', correct_top1/num_tests);\n",
    "print('top-5 accuracy:', correct_top5/num_tests);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "7"
   },
   "source": [
    "## References\n",
    "\n",
    "> <a id=\"1\">[1]</a>\n",
    "Ankur Singh,\n",
    "[\"End-to-end Masked Language Modeling with BERT\"](https://keras.io/examples/nlp/masked_language_modeling)\n",
    "\n",
    "> <a id=\"2\">[2]</a>\n",
    "[CodeXGLUE -- Code Completion (token level)](https://github.com/microsoft/CodeXGLUE/tree/main/Code-Code/CodeCompletion-token)\n",
    "\n",
    "> <a id=\"3\">[3]</a>\n",
    "[AllenNLP Demo on Masked Language Modeling](https://demo.allennlp.org/masked-lm)\n",
    "\n",
    "> <a id=\"4\">[4]</a>\n",
    "Zhangyin Feng, Daya Guo, Duyu Tang, Nan Duan, Xiaocheng Feng, Ming Gong,\n",
    "Linjun Shou, Bing Qin, Ting Liu, Daxin Jiang, Ming Zhou,\n",
    "[CodeBERT: A Pre-Trained Model for Programming and Natural Languages](https://arxiv.org/abs/2002.08155)"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "Project_CodeNet_MLM.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
